import itertools
import json
import os
from collections import Counter

import numpy as np
import trimesh
from matplotlib.path import Path
from sklearn.neighbors import KNeighborsClassifier


def load_segmentation(path, shape):
    """
    Load the garment polygons stored in a segmentation json file.

    Only top-level entries whose key starts with "item" are read; every other
    key is ignored. Each item's "segmentation" entry is a list of polygons,
    each polygon stored as a flat coordinate list [x1, y1, x2, y2, ...].

    Arguments:
        path: path to the segmentation json file
        shape: currently unused; kept so existing callers keep working
    Returns:
        A list with one dict per item (in file order) containing:
            "type": the item's "category_name"
            "type_id": the item's "category_id"
            "coordinates": list of (n, 2) arrays, one [x, y] row per polygon
                point; an item may contain several polygons
    """
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(path, encoding="utf-8") as json_file:
        data = json.load(json_file)

    segmentations = []
    for key, val in data.items():
        if not key.startswith("item"):
            continue

        coordinates = []
        for segmentation_coord in val["segmentation"]:
            # De-interleave [x1, y1, x2, y2, ...] into an (n, 2) array of points.
            x = segmentation_coord[::2]
            y = segmentation_coord[1::2]
            coordinates.append(np.vstack((x, y)).T)

        segmentations.append({
            "type": val["category_name"],
            "type_id": val["category_id"],
            "coordinates": coordinates,
        })

    return segmentations


def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Transfer per-vertex body-part labels from an SMPL mesh to a reconstructed mesh.

    SMPL vertices are labelled from "smpl_vert_segmentation.json" (located next
    to this module), which maps a body-part name to a list of SMPL vertex
    indices. A k-nearest-neighbour classifier fitted on the SMPL vertex
    positions then assigns each recon vertex a body-part label.

    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest SMPL vertices voting on each recon vertex's label
    Returns:
        A dictionary mapping every body-part name from the json file to the list
        of recon vertex indices assigned to that part
    """
    label_path = os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json")
    with open(label_path) as label_file:
        smpl_vert_segmentation = json.load(label_file)

    n = smpl.vertices.shape[0]
    # NOTE(review): vertices not listed under any part keep the label None, and a
    # vertex listed under several parts keeps the last one; presumably the json
    # partitions all SMPL vertices exactly once — confirm.
    y = np.array([None] * n)
    for key, val in smpl_vert_segmentation.items():
        y[val] = key

    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)

    y_pred = classifier.predict(recon.vertices)

    return {
        key: list(np.flatnonzero(y_pred == key).astype(int))
        for key in smpl_vert_segmentation
    }


# Body parts never kept in a garment: they often overlap clothing in the image
# (e.g. a hand in front of the torso) but are not part of it.
_ALWAYS_REMOVED_PARTS = (
    "rightHand",
    "leftToeBase",
    "leftFoot",
    "rightFoot",
    "head",
    "leftHandIndex1",
    "rightHandIndex1",
    "rightToeBase",
    "leftHand",
)


def _body_parts_to_remove(type_id):
    """Return the SMPL body-part names to strip from a garment with this category id."""
    # Category ids follow DeepFashion2: https://github.com/switchablenorms/DeepFashion2
    parts = list(_ALWAYS_REMOVED_PARTS)
    # Short sleeve clothes
    if type_id in (1, 3, 10):
        parts += ["leftForeArm", "rightForeArm"]
    # No sleeves at all or lower body clothes
    elif type_id in (5, 6, 8, 9, 12, 13):
        parts += ["leftForeArm", "rightForeArm", "leftArm", "rightArm"]
    # Shorts
    elif type_id == 7:
        parts += [
            "leftLeg",
            "rightLeg",
            "leftForeArm",
            "rightForeArm",
            "leftArm",
            "rightArm",
        ]
    return parts


def _vertices_inside_polygons(vertices, polygons, P_inv):
    """Return a boolean mask of vertices whose (x, y) lie inside any back-projected polygon."""
    inside = np.zeros(len(vertices), dtype=bool)
    for polygon in polygons:
        polygon = np.asarray(polygon)
        coords_h = np.hstack((polygon, np.ones((len(polygon), 1))))
        # (4, 3) @ (n, 3, 1) -> (n, 4, 1): one homogeneous 3D point per 2D point.
        XYZ = (P_inv @ coords_h[:, :, None])[:, :, 0]
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]
        # The containment test uses only x/y, so depth (z) is ignored.
        inside |= Path(XYZ[:, :2]).contains_points(vertices[:, :2])
    return inside


def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2D segmentation polygons.

    Each polygon is mapped back to 3D with the pseudo-inverse of the projection
    matrix K[:3, :3] @ [R | t]. A recon vertex is selected when its (x, y)
    position falls inside any of the resulting outlines.

    Arguments:
        recon: fully clothed mesh (needs .vertices and .faces)
        segmentation: dict with
            "coord_normalized": iterable of polygons, each an (n, 2) array of 2D
                segmentation coordinates (NDC according to the original docs)
            "type_id": garment category id (only read when smpl is given)
        K: intrinsic matrix of the projection (only K[:3, :3] is used)
        R: 3x3 rotation matrix of the projection
        t: length-3 translation vector of the projection
        smpl: optional SMPL mesh. When given, vertices labelled as body parts
            that should not belong to this garment type (hands, feet, head, and
            depending on the type also arms/legs) are dropped from the selection.
    Returns:
        A trimesh.Trimesh holding every face that has at least one selected
        vertex (unreferenced vertices removed), or None if no face is selected.
    """
    faces = recon.faces
    num_verts = recon.vertices.shape[0]

    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P = K[:3, :3] @ extrinsic
    P_inv = np.linalg.pinv(P)

    # Each segmentation can contain multiple polygons; a vertex counts if it lies
    # inside any of them.
    seg_mask = _vertices_inside_polygons(
        recon.vertices, segmentation["coord_normalized"], P_inv
    )

    if smpl is not None:
        recon_labels = smpl_to_recon_labels(recon, smpl)
        label_mask = np.zeros(num_verts, dtype=bool)
        for part in _body_parts_to_remove(segmentation["type_id"]):
            label_mask[recon_labels[part]] = True
        # Keep only segmented vertices that do not belong to an excluded body part.
        seg_mask &= ~label_mask

    # A face is kept when any of its three vertices is selected.
    face_mask = seg_mask[faces].any(axis=1)
    if not face_mask.any():
        return None

    mesh = trimesh.Trimesh(recon.vertices, recon.faces)
    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()
    return mesh
